{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "掩码模型训练",
   "id": "72f0c2961d5314b1"
  },
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-05-21T13:40:58.133317Z",
     "start_time": "2025-05-21T13:40:53.789700Z"
    }
   },
   "source": [
    "from datasets import load_dataset\n",
    "from MyHelper import *\n",
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer, DataCollatorForWholeWordMask\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ],
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\51165\\.conda\\envs\\e12\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "加载数据集",
   "id": "3926ff3a848572dd"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T13:40:59.714176Z",
     "start_time": "2025-05-21T13:40:58.233765Z"
    }
   },
   "cell_type": "code",
   "source": [
    "model_name = Config.hfl_chinese_macbert_base\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForMaskedLM.from_pretrained(model_name)"
   ],
   "id": "8454a1d4f5d6abb1",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/chinese-macbert-base were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T13:41:05.288809Z",
     "start_time": "2025-05-21T13:40:59.727180Z"
    }
   },
   "cell_type": "code",
   "source": [
    "dataset_name = \"pleisto/wikipedia-cn-20230720-filtered\"\n",
    "ds = load_dataset(dataset_name)\n",
    "ds, ds.get(\"train\")[0]"
   ],
   "id": "a26b1c503b94eab9",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DatasetDict({\n",
       "     train: Dataset({\n",
       "         features: ['completion', 'source'],\n",
       "         num_rows: 254547\n",
       "     })\n",
       " }),\n",
       " {'completion': '昭通机场（ZPZT）是位于中国云南昭通的民用机场，始建于1935年，1960年3月开通往返航班“昆明－昭通”，原来属军民合用机场。1986年机场停止使用。1991年11月扩建，于1994年2月恢复通航。是西南地区「文明机场」，通航城市昆明。 机场占地1957亩，飞行区等级为4C，有一条跑道，长2720米，宽48米，可供波音737及以下机型起降。机坪面积6600平方米，停机位2个，航站楼面积1900平方米。位于城东6公里处，民航路与金鹰大道交叉处。\\n航点\\n客服电话\\n昭通机场客服电话：0870-2830004',\n",
       "  'source': 'wikipedia.zh2307'})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T13:54:21.002827Z",
     "start_time": "2025-05-21T13:53:55.266150Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def process_function(examples, tokenizer=tokenizer):\n",
    "    return tokenizer(examples[\"completion\"], max_length=384, truncation=True)\n",
    "ds = ds.map(process_function, batched=True, num_proc=16, batch_size=64, remove_columns=ds[\"train\"].column_names)"
   ],
   "id": "af29533ff3cd41b3",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map (num_proc=16): 100%|██████████| 254547/254547 [00:25<00:00, 9940.13 examples/s] \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input_ids', 'token_type_ids', 'attention_mask'],\n",
       "        num_rows: 254547\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 8
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T13:59:36.482590Z",
     "start_time": "2025-05-21T13:59:14.429921Z"
    }
   },
   "cell_type": "code",
   "source": [
    "args = TrainingArguments(\n",
    "    output_dir=\"output/0207/\",\n",
    "    per_device_train_batch_size=1,\n",
    "    gradient_accumulation_steps=16,\n",
    "    max_steps=30,\n",
    "    logging_steps=10\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=ds[\"train\"],\n",
    "    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15),\n",
    ")\n",
    "\n",
    "trainer.train()"
   ],
   "id": "5947e3fff3cf1ca7",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ],
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='30' max='30' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [30/30 00:20, Epoch 0/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>1.415300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>1.444100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>1.384600</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=30, training_loss=1.414657211303711, metrics={'train_runtime': 21.8343, 'train_samples_per_second': 21.984, 'train_steps_per_second': 1.374, 'total_flos': 79365540805008.0, 'train_loss': 1.414657211303711, 'epoch': 0.001885702836804205})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 15
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T14:02:21.788012Z",
     "start_time": "2025-05-21T14:02:21.734827Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from transformers import pipeline\n",
    "pipe = pipeline(\"fill-mask\", model=model, tokenizer=tokenizer)\n"
   ],
   "id": "cccc5931305e5635",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Device set to use cuda:0\n"
     ]
    }
   ],
   "execution_count": 16
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T14:02:22.208935Z",
     "start_time": "2025-05-21T14:02:22.031541Z"
    }
   },
   "cell_type": "code",
   "source": "pipe(\"西安交通[MASK][MASK]博物馆（Xi'an Jiaotong University Museum）是一座位于西安交通大学的博物馆\")\n",
   "id": "541ae9d096476bdb",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'score': 0.9973401427268982,\n",
       "   'token': 1920,\n",
       "   'token_str': '大',\n",
       "   'sequence': \"[CLS] 西 安 交 通 大 [MASK] 博 物 馆 （ xi ' an jiaotong university museum ） 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]\"},\n",
       "  {'score': 0.0011449242010712624,\n",
       "   'token': 2110,\n",
       "   'token_str': '学',\n",
       "   'sequence': \"[CLS] 西 安 交 通 学 [MASK] 博 物 馆 （ xi ' an jiaotong university museum ） 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]\"},\n",
       "  {'score': 0.0002696097071748227,\n",
       "   'token': 4906,\n",
       "   'token_str': '科',\n",
       "   'sequence': \"[CLS] 西 安 交 通 科 [MASK] 博 物 馆 （ xi ' an jiaotong university museum ） 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]\"},\n",
       "  {'score': 0.00013024771760683507,\n",
       "   'token': 5466,\n",
       "   'token_str': '职',\n",
       "   'sequence': \"[CLS] 西 安 交 通 职 [MASK] 博 物 馆 （ xi ' an jiaotong university museum ） 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]\"},\n",
       "  {'score': 9.580811456544325e-05,\n",
       "   'token': 2339,\n",
       "   'token_str': '工',\n",
       "   'sequence': \"[CLS] 西 安 交 通 工 [MASK] 博 物 馆 （ xi ' an jiaotong university museum ） 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]\"}],\n",
       " [{'score': 0.9981729984283447,\n",
       "   'token': 2110,\n",
       "   'token_str': '学',\n",
       "   'sequence': \"[CLS] 西 安 交 通 [MASK] 学 博 物 馆 （ xi ' an jiaotong university museum ） 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]\"},\n",
       "  {'score': 0.0004937463090755045,\n",
       "   'token': 1920,\n",
       "   'token_str': '大',\n",
       "   'sequence': \"[CLS] 西 安 交 通 [MASK] 大 博 物 馆 （ xi ' an jiaotong university museum ） 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]\"},\n",
       "  {'score': 0.0004145340935792774,\n",
       "   'token': 7368,\n",
       "   'token_str': '院',\n",
       "   'sequence': \"[CLS] 西 安 交 通 [MASK] 院 博 物 馆 （ xi ' an jiaotong university museum ） 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]\"},\n",
       "  {'score': 9.150406549451873e-05,\n",
       "   'token': 3318,\n",
       "   'token_str': '术',\n",
       "   'sequence': \"[CLS] 西 安 交 通 [MASK] 术 博 物 馆 （ xi ' an jiaotong university museum ） 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]\"},\n",
       "  {'score': 8.517364040017128e-05,\n",
       "   'token': 3413,\n",
       "   'token_str': '校',\n",
       "   'sequence': \"[CLS] 西 安 交 通 [MASK] 校 博 物 馆 （ xi ' an jiaotong university museum ） 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]\"}]]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 17
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T14:02:27.275550Z",
     "start_time": "2025-05-21T14:02:27.186238Z"
    }
   },
   "cell_type": "code",
   "source": "pipe(\"下面是一则[MASK][MASK]新闻。小编报道，近日，游戏产业发展的非常好！\")\n",
   "id": "e5c4f8f57020260",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'score': 0.14391231536865234,\n",
       "   'token': 4685,\n",
       "   'token_str': '相',\n",
       "   'sequence': '[CLS] 下 面 是 一 则 相 [MASK] 新 闻 。 小 编 报 道 ， 近 日 ， 游 戏 产 业 发 展 的 非 常 好 ！ [SEP]'},\n",
       "  {'score': 0.10960014164447784,\n",
       "   'token': 3952,\n",
       "   'token_str': '游',\n",
       "   'sequence': '[CLS] 下 面 是 一 则 游 [MASK] 新 闻 。 小 编 报 道 ， 近 日 ， 游 戏 产 业 发 展 的 非 常 好 ！ [SEP]'},\n",
       "  {'score': 0.056944213807582855,\n",
       "   'token': 3173,\n",
       "   'token_str': '新',\n",
       "   'sequence': '[CLS] 下 面 是 一 则 新 [MASK] 新 闻 。 小 编 报 道 ， 近 日 ， 游 戏 产 业 发 展 的 非 常 好 ！ [SEP]'},\n",
       "  {'score': 0.029918836429715157,\n",
       "   'token': 2031,\n",
       "   'token_str': '娱',\n",
       "   'sequence': '[CLS] 下 面 是 一 则 娱 [MASK] 新 闻 。 小 编 报 道 ， 近 日 ， 游 戏 产 业 发 展 的 非 常 好 ！ [SEP]'},\n",
       "  {'score': 0.029363568872213364,\n",
       "   'token': 3297,\n",
       "   'token_str': '最',\n",
       "   'sequence': '[CLS] 下 面 是 一 则 最 [MASK] 新 闻 。 小 编 报 道 ， 近 日 ， 游 戏 产 业 发 展 的 非 常 好 ！ [SEP]'}],\n",
       " [{'score': 0.1781068742275238,\n",
       "   'token': 1068,\n",
       "   'token_str': '关',\n",
       "   'sequence': '[CLS] 下 面 是 一 则 [MASK] 关 新 闻 。 小 编 报 道 ， 近 日 ， 游 戏 产 业 发 展 的 非 常 好 ！ [SEP]'},\n",
       "  {'score': 0.04587458819150925,\n",
       "   'token': 7481,\n",
       "   'token_str': '面',\n",
       "   'sequence': '[CLS] 下 面 是 一 则 [MASK] 面 新 闻 。 小 编 报 道 ， 近 日 ， 游 戏 产 业 发 展 的 非 常 好 ！ [SEP]'},\n",
       "  {'score': 0.045657407492399216,\n",
       "   'token': 2767,\n",
       "   'token_str': '戏',\n",
       "   'sequence': '[CLS] 下 面 是 一 则 [MASK] 戏 新 闻 。 小 编 报 道 ， 近 日 ， 游 戏 产 业 发 展 的 非 常 好 ！ [SEP]'},\n",
       "  {'score': 0.03221156820654869,\n",
       "   'token': 2141,\n",
       "   'token_str': '实',\n",
       "   'sequence': '[CLS] 下 面 是 一 则 [MASK] 实 新 闻 。 小 编 报 道 ， 近 日 ， 游 戏 产 业 发 展 的 非 常 好 ！ [SEP]'},\n",
       "  {'score': 0.030621333047747612,\n",
       "   'token': 6380,\n",
       "   'token_str': '讯',\n",
       "   'sequence': '[CLS] 下 面 是 一 则 [MASK] 讯 新 闻 。 小 编 报 道 ， 近 日 ， 游 戏 产 业 发 展 的 非 常 好 ！ [SEP]'}]]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 18
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "25942c0207a6b5b3"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
